#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date  : 2025/8/27
# @File  : train_outline_agent.py
# @Author: johnson
# @Desc  : 训练一个“遍历大纲→搜索→填充text→保持原格式返回”的 ReAct Agent（ART + LangGraph，GRPO）
import logging
logging.basicConfig(level=logging.DEBUG)
import os
import re
import json
import uuid
import time
import asyncio
import dotenv
import wandb
import urllib.parse
from dataclasses import dataclass
from statistics import mean
from textwrap import dedent
from typing import List, Dict, Any, Optional, Tuple

import art
from art.langgraph import init_chat_model, wrap_rollout
from art.utils import iterate_dataset
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from pydantic import BaseModel, Field, ValidationError
from tenacity import retry, stop_after_attempt
from reward import search_reward,format_reward

# ==== 你的 Web 搜索客户端（示例：ZhipuAiClient）====
from zai import ZhipuAiClient
# Populate os.environ from a local .env file before any key is read below.
dotenv.load_dotenv()
# The Zhipu API key is mandatory: a missing ZHIPU_API_KEY raises KeyError at import.
_ZHIPU_API_KEY = os.environ["ZHIPU_API_KEY"]
WebSearchClient = ZhipuAiClient(api_key=_ZHIPU_API_KEY)

# ---------------- wandb: 运行配置 ----------------
# ART run identity; each value can be overridden through the environment.
NAME = os.getenv("ART_NAME", "outline-webfill")
MODEL_NAME = os.getenv("ART_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
PROJECT_NAME = os.getenv("ART_PROJECT", "content-training")
# Any ART_BACKEND value other than "local" (case-insensitive) selects SkyPilot.
USE_LOCAL_BACKEND = os.getenv("ART_BACKEND", "local").lower() == "local"

WANDB_PROJECT = os.getenv("WANDB_PROJECT", PROJECT_NAME)
WANDB_ENTITY = os.getenv("WANDB_ENTITY")  # optional; None when unset
WANDB_RUN_NAME = os.getenv("WANDB_RUN_NAME", f"{NAME}-{time.strftime('%Y%m%d-%H%M%S')}")

# WANDB_BASE_URL is optional (wandb has its own default server), so read it with
# getenv: indexing os.environ would crash the whole script with KeyError at
# import time just to print this informational banner.
print(f"{NAME} - {MODEL_NAME} - {PROJECT_NAME} - {os.getenv('WANDB_BASE_URL', '<wandb default>')}")

# ----------------- 数据结构 -----------------
class WebSearchResult(BaseModel):
    """One web-search hit, as built by search_web and returned by web_search_tool."""
    # Result page URL (search_web fills it from the API hit's `link`).
    url: str
    # Result page title.
    title: str
    # Text excerpt (search_web fills it from the API hit's `content`).
    snippet: str

class FinalFilledOutline(BaseModel):
    """Outline submitted by the agent via return_filled_outline_tool, after validation."""
    task: List[Dict[str, Any]]                 # returned task array, expected to keep the input's original format
    sources: List[str] = Field(default_factory=list)  # URL list that the [1], [2], ... citations index into

@dataclass
class OutlineScenario:
    """One training example: an outline whose `text` fields the agent must fill."""
    id: str                   # scenario id; rollout copies it into trajectory metadata["scenario_id"]
    prompt: str               # instruction text for the agent. NOTE(review): build_scenarios_from_train_data sets a fixed string without the JSON, and rollout builds its own user message without reading this field
    input_task: List[Dict[str, Any]]  # original task structure; reference for format_reward and source of the user-message JSON

class ProjectTrajectory(art.Trajectory):
    """ART trajectory extended with the outline the agent finally submitted."""
    # Set by rollout when return_filled_outline_tool produced a valid outline; None otherwise.
    final: Optional[FinalFilledOutline] = None

# ----------------- 训练数据 -----------------
# Raw training rows. Each row is a dict with:
#   "difficulty": int  -- not read anywhere in this file;
#   "task": list of slide dicts, each with a "type" ("cover", "contents",
#           "transition", "content" or "end") and, except for "end", a "data"
#           payload. "content" slides hold "items" of {"title", "text"}.
# The "text" values are placeholders that the agent is asked to replace with
# sourced content; all other fields are supposed to come back unchanged.
train_data = [
      {
        "difficulty": 1,
        "task": [
          {"type": "cover", "data": {"title": "2025年市场趋势与增长机会", "text": "A presentation generated by AI"}},
          {"type": "contents", "data": {"items": ["数字化与人工智能驱动", "消费升级与新兴业态", "产业创新与转型升级", "渠道变革与市场拓展", "政策红利与投资机会"]}},
          {"type": "transition", "data": {"title": "数字化与人工智能驱动", "text": "Exploring the topic of 数字化与人工智能驱动"}},
          {"type": "content", "data": {"title": "人工智能技术应用", "items": [
            {"title": "部署AI辅助诊断系统", "text": "Detailed content about 部署AI辅助诊断系统"},
            {"title": "开发智能推荐系统", "text": "Detailed content about 开发智能推荐系统"}
          ]}},
          {"type": "content", "data": {"title": "数字化转型加速", "items": [
            {"title": "推动设备加装智能传感器", "text": "Detailed content about 推动设备加装智能传感器"},
            {"title": "推广算力资源普惠服务", "text": "Detailed content about 推广算力资源普惠服务"}
          ]}},
          {"type": "transition", "data": {"title": "消费升级与新兴业态", "text": "Exploring the topic of 消费升级与新兴业态"}},
          {"type": "content", "data": {"title": "懒人经济崛起", "items": [
            {"title": "推出一键解决智能产品", "text": "Detailed content about 推出一键解决智能产品"},
            {"title": "提供极致便捷服务体验", "text": "Detailed content about 提供极致便捷服务体验"}
          ]}},
          {"type": "content", "data": {"title": "体验经济爆发", "items": [
            {"title": "打造沉浸式旗舰店", "text": "Detailed content about 打造沉浸式旗舰店"},
            {"title": "创建展览餐饮零售一体空间", "text": "Detailed content about 创建展览餐饮零售一体空间"}
          ]}},
          {"type": "transition", "data": {"title": "产业创新与转型升级", "text": "Exploring the topic of 产业创新与转型升级"}},
          {"type": "content", "data": {"title": "新能源汽车发展", "items": [
            {"title": "实现渗透率超50%", "text": "Detailed content about 实现渗透率超50%"},
            {"title": "开发超快充技术", "text": "Detailed content about 开发超快充技术"}
          ]}},
          {"type": "content", "data": {"title": "智能制造推进", "items": [
            {"title": "应用工业机器人", "text": "Detailed content about 应用工业机器人"},
            {"title": "实现人机物协同", "text": "Detailed content about 实现人机物协同"}
          ]}},
          {"type": "transition", "data": {"title": "渠道变革与市场拓展", "text": "Exploring the topic of 渠道变革与市场拓展"}},
          {"type": "content", "data": {"title": "线上线下融合", "items": [
            {"title": "优化私域流量运营", "text": "Detailed content about 优化私域流量运营"},
            {"title": "部署AR虚拟试穿", "text": "Detailed content about 部署AR虚拟试穿"}
          ]}},
          {"type": "content", "data": {"title": "跨境与出海布局", "items": [
            {"title": "利用海南自贸港政策", "text": "Detailed content about 利用海南自贸港政策"},
            {"title": "建立区域配送中心", "text": "Detailed content about 建立区域配送中心"}
          ]}},
          {"type": "transition", "data": {"title": "政策红利与投资机会", "text": "Exploring the topic of 政策红利与投资机会"}},
          {"type": "content", "data": {"title": "政府支持政策", "items": [
            {"title": "享受设备更新资金支持", "text": "Detailed content about 享受设备更新资金支持"},
            {"title": "获取模型券补贴", "text": "Detailed content about 获取模型券补贴"}
          ]}},
          {"type": "content", "data": {"title": "新兴投资领域", "items": [
            {"title": "投资人工智能研发项目", "text": "Detailed content about 投资人工智能研发项目"},
            {"title": "关注沉浸式体验场景", "text": "Detailed content about 关注沉浸式体验场景"}
          ]}},
          {"type": "end"}
        ]
      },
      {
        "difficulty": 1,
        "task": [
          {"type": "cover", "data": {"title": "新产品发布会：从概念到落地", "text": "A presentation generated by AI"}},
          {"type": "contents", "data": {"items": ["前期策划与概念设计", "内容策划与产品展示", "场地选择与现场布置", "执行管理与现场运营", "后期跟进与效果评估"]}},
          {"type": "transition", "data": {"title": "前期策划与概念设计", "text": "Exploring the topic of 前期策划与概念设计"}},
          {"type": "content", "data": {"title": "发布会定位与目标设定", "items": [
            {"title": "明确核心目的", "text": "Detailed content about 明确核心目的"},
            {"title": "确定目标受众", "text": "Detailed content about 确定目标受众"}
          ]}},
          {"type": "content", "data": {"title": "主题创意与概念设计", "items": [
            {"title": "构思吸引力主题", "text": "Detailed content about 构思吸引力主题"},
            {"title": "设计统一视觉系统", "text": "Detailed content about 设计统一视觉系统"}
          ]}},
          {"type": "transition", "data": {"title": "内容策划与产品展示", "text": "Exploring the topic of 内容策划与产品展示"}},
          {"type": "content", "data": {"title": "产品核心价值提炼", "items": [
            {"title": "挖掘独特卖点", "text": "Detailed content about 挖掘独特卖点"},
            {"title": "设计简洁介绍", "text": "Detailed content about 设计简洁介绍"}
          ]}},
          {"type": "content", "data": {"title": "演讲内容与嘉宾安排", "items": [
            {"title": "策划主题演讲", "text": "Detailed content about 策划主题演讲"},
            {"title": "邀请行业专家", "text": "Detailed content about 邀请行业专家"}
          ]}},
          {"type": "transition", "data": {"title": "场地选择与现场布置", "text": "Exploring the topic of 场地选择与现场布置"}},
          {"type": "content", "data": {"title": "场地评估与选择", "items": [
            {"title": "考察酒店或场馆", "text": "Detailed content about 考察酒店或场馆"},
            {"title": "评估交通便利性", "text": "Detailed content about 评估交通便利性"}
          ]}},
          {"type": "content", "data": {"title": "舞台设计与搭建", "items": [
            {"title": "设计舞台布局", "text": "Detailed content about 设计舞台布局"},
            {"title": "安排灯光音响", "text": "Detailed content about 安排灯光音响"}
          ]}},
          {"type": "transition", "data": {"title": "执行管理与现场运营", "text": "Exploring the topic of 执行管理与现场运营"}},
          {"type": "content", "data": {"title": "团队分工与协作", "items": [
            {"title": "明确职责分工", "text": "Detailed content about 明确职责分工"},
            {"title": "建立沟通机制", "text": "Detailed content about 建立沟通机制"}
          ]}},
          {"type": "content", "data": {"title": "现场流程控制", "items": [
            {"title": "制定时间流程表", "text": "Detailed content about 制定时间流程表"},
            {"title": "处理突发状况", "text": "Detailed content about 处理突发状况"}
          ]}},
          {"type": "transition", "data": {"title": "后期跟进与效果评估", "text": "Exploring the topic of 后期跟进与效果评估"}},
          {"type": "content", "data": {"title": "媒体传播与舆情监测", "items": [
            {"title": "跟踪媒体报道", "text": "Detailed content about 跟踪媒体报道"},
            {"title": "监测社交媒体", "text": "Detailed content about 监测社交媒体"}
          ]}},
          {"type": "content", "data": {"title": "经验总结与优化改进", "items": [
            {"title": "召开复盘会议", "text": "Detailed content about 召开复盘会议"},
            {"title": "制定改进措施", "text": "Detailed content about 制定改进措施"}
          ]}},
          {"type": "end"}
        ]
      }
    ]

# ----------------- 工具：web 搜索 -----------------
async def search_web(keyword: str) -> List[WebSearchResult]:
    """Search the web for ``keyword`` through the Zhipu web-search API.

    ``WebSearchClient.web_search.web_search`` is a plain, non-awaitable call.
    It is therefore run in a worker thread with ``asyncio.to_thread``. Calling
    it inline inside this coroutine would block the event loop, and with it any
    other coroutines (e.g. concurrently gathered rollouts), for the whole HTTP
    request.

    Args:
        keyword: Query string produced by the agent.

    Returns:
        Up to 4 hits (``count=4``) as ``WebSearchResult`` objects, or an empty
        list when the API returns no results. Hits without a link are dropped,
        since they cannot be cited. A missing title or content becomes ``""``,
        so the ``str`` fields of ``WebSearchResult`` always validate.
    """
    response = await asyncio.to_thread(
        WebSearchClient.web_search.web_search,
        search_engine="search_std",
        search_query=keyword,
        count=4,
        search_recency_filter="noLimit",
        content_size="high",
    )
    hits = response.search_result or []
    return [
        WebSearchResult(url=hit.link, title=hit.title or "", snippet=hit.content or "")
        for hit in hits
        if hit.link
    ]

# ----------------- rollout（核心）：LangGraph + Tools -----------------
async def rollout(model: art.Model, scenario: OutlineScenario) -> ProjectTrajectory:
    """Run one ReAct episode on ``scenario`` and return the scored trajectory.

    A LangGraph ``create_react_agent`` receives the outline as JSON and two tools:
    - ``web_search_tool``: retrieval via ``search_web``.
    - ``return_filled_outline_tool``: submits the filled outline, validated as
      ``FinalFilledOutline``.

    If an outline was submitted, the reward is
    ``0.5 * format_reward + 0.5 * search_reward``. Otherwise it stays 0.0.

    Args:
        model: The ART model being trained; ``model.name`` is passed to
            ``init_chat_model``.
        scenario: The outline to fill.

    Returns:
        A ``ProjectTrajectory`` with ``reward``, ``final`` (when an outline was
        submitted) and the ``format_reward`` / ``search_reward`` /
        ``sources_count`` metrics.
    """
    from langgraph.errors import GraphRecursionError

    # Passed as LangGraph's recursion_limit, which bounds graph steps. The agent
    # node and the tool node each use one step, so this allows roughly
    # MAX_TURNS / 2 LLM calls.
    MAX_TURNS = 16
    traj = ProjectTrajectory(
        reward=0.0,
        messages_and_choices=[],
        metadata={"scenario_id": scenario.id}
    )

    # System prompt (Chinese, sent verbatim to the model). It asks the agent to:
    # traverse the outline, search, fill each text with 2-4 sourced sentences and
    # [n] citations, keep the structure unchanged, and submit via the return tool.
    system_prompt = dedent("""
    你是一个结构化写作助手。你的任务：
    1) 遍历用户给定的 presentation 大纲（task 数组）；
    2) 针对每个需要填充的 text，利用 web_search_tool 生成精准查询（包含主题、地域、时间点、产业/政策关键词），搜索并阅读多条结果；
    3) 产出 2~4 句话的扎实文本，包含具体数字/时间/公司/政策/产品名等可核验事实，并在句尾添加来源引用 [n]（n 对应 sources 中 URL 的编号，从 1 开始）；
    4) 保持原 task 的结构完全一致（type、data.title、items 和 item.title 不得变更）。仅“text”允许替换与扩写；
    5) 完成后调用 return_filled_outline_tool(task, sources)：
       - task：与输入结构一致的数组；
       - sources：去重后的 URL 列表，顺序与 [n] 对应；
    6) 不要在普通对话中粘贴 JSON，务必用工具返回最终 JSON。
    """)

    # URLs returned by web_search_tool during this episode. They are passed to
    # search_reward so it can check that cited sources really came from the tool.
    # The tool records them itself: traj.messages_and_choices is created empty
    # above and nothing in this function fills it, so it holds no tool outputs
    # to mine at reward time.
    tool_urls_seen: List[str] = []

    # NOTE: @tool uses each function's docstring as the tool description shown
    # to the model, so the Chinese docstrings below are part of the prompt.
    @tool
    async def web_search_tool(query: str) -> List[dict]:
        """通过关键词查询网页，返回[{url,title,snippet}, ...]。"""
        results = await search_web(query)
        tool_urls_seen.extend(r.url for r in results)
        return [r.model_dump() for r in results]

    # Last successfully validated submission (a later valid call overwrites it).
    final: Optional[FinalFilledOutline] = None

    @tool
    def return_filled_outline_tool(task: List[Dict[str, Any]], sources: List[str]) -> dict:
        """返回最终 JSON：保持原格式的 task 与 sources。"""
        nonlocal final
        try:
            final = FinalFilledOutline(task=task, sources=sources or [])
            return final.model_dump()
        except ValidationError as e:
            # Return the error as the tool result so the agent can fix its call.
            return {"error": f"ValidationError: {str(e)}"}

    tools = [web_search_tool, return_filled_outline_tool]

    # Use the trainable model's name. main registers it under NAME with
    # base_model=MODEL_NAME, so MODEL_NAME is only the base checkpoint id.
    chat_model = init_chat_model(model.name, temperature=0.8)
    agent = create_react_agent(chat_model, tools)

    # The outline goes into the user message as JSON so its structure is kept verbatim.
    user_msg = dedent(f"""
    请严格按系统要求处理以下 JSON 任务大纲（只替换 text）：
    {json.dumps(scenario.input_task, ensure_ascii=False)}
    """)

    try:
        await agent.ainvoke(
            {"messages": [SystemMessage(content=system_prompt),
                          HumanMessage(content=user_msg)]},
            config={"configurable": {"thread_id": str(uuid.uuid4())},
                    "recursion_limit": MAX_TURNS},
        )
    except GraphRecursionError:
        # Out of steps. Keep the trajectory instead of failing the rollout:
        # - an outline submitted before the limit is still scored below;
        # - an unfinished episode becomes a reward-0 sample for GRPO.
        logging.warning("rollout %s hit recursion_limit=%d", scenario.id, MAX_TURNS)

    if final is not None:
        traj.final = final
        fr = format_reward(scenario.input_task, final.task)
        sr = search_reward(final.task, final.sources, tool_urls_seen=tool_urls_seen)

        # Equal weighting of structure fidelity and sourcing quality.
        traj.reward = 0.5 * fr + 0.5 * sr
        traj.metrics["format_reward"] = fr
        traj.metrics["search_reward"] = sr
        traj.metrics["sources_count"] = len(set(final.sources))

    return traj

# ----------------- wandb 记录 -----------------
def _log_batch_to_wandb(*, batch, finished_groups):
    """Log one training batch to wandb.

    The following is logged at ``step=batch.step``:
    - the step and epoch;
    - batch-mean total / format / search rewards (when any trajectories exist);
    - a sample table of the first 50 trajectories with their reward components
      and cited sources.

    Args:
        batch: Item yielded by ``iterate_dataset``; only ``.step`` and
            ``.epoch`` are read.
        finished_groups: Value returned by ``art.gather_trajectory_groups``.
            Groups exposing ``.trajectories`` are used directly; anything else
            is treated as an iterable of trajectories.
    """
    trajectories = []
    for g in finished_groups:
        if hasattr(g, "trajectories"):
            trajectories.extend(g.trajectories)
            continue
        try:
            trajectories.extend(list(g))
        except Exception:
            # Best-effort: a malformed group must not abort training, but it is
            # reported instead of being dropped silently.
            logging.exception("Skipping trajectory group of type %s", type(g).__name__)

    def _metric(t, key):
        # Reward component stored by rollout in t.metrics. It is absent (-> 0.0)
        # when the agent never submitted an outline.
        return (getattr(t, "metrics", {}) or {}).get(key, 0.0)

    table = wandb.Table(columns=["scenario_id", "format_reward", "search_reward", "total_reward", "sources"])
    for t in trajectories[:50]:
        sid = (getattr(t, "metadata", {}) or {}).get("scenario_id", "")
        rw = getattr(t, "reward", 0.0)
        srcs = ", ".join(getattr(getattr(t, "final", None), "sources", []) or [])
        table.add_data(sid, _metric(t, "format_reward"), _metric(t, "search_reward"), rw, srcs)

    payload = {
        "train/step": batch.step,
        "train/epoch": batch.epoch,
        "samples/trajectories": table,
    }
    # statistics.mean raises on empty input, so aggregates are only logged
    # when at least one trajectory finished.
    if trajectories:
        payload["train/num_trajectories"] = len(trajectories)
        payload["train/reward_mean"] = mean(getattr(t, "reward", 0.0) for t in trajectories)
        payload["train/format_reward_mean"] = mean(_metric(t, "format_reward") for t in trajectories)
        payload["train/search_reward_mean"] = mean(_metric(t, "search_reward") for t in trajectories)

    wandb.log(payload, step=batch.step)

# ----------------- 构造训练集 -----------------
def build_scenarios_from_train_data(td: List[Dict[str, Any]]) -> List[OutlineScenario]:
    """Convert raw training rows into ``OutlineScenario`` objects.

    Each row's ``"task"`` list (``[]`` when the key is missing) becomes the
    scenario's ``input_task``. Ids are numbered ``doc_1``, ``doc_2``, ... in row
    order, and every scenario gets the same fixed instruction as ``prompt``.
    """
    shared_prompt = "填充大纲 json（仅替换 text，加入 [n] 引用并返回 sources）。"
    return [
        OutlineScenario(id=f"doc_{index}", prompt=shared_prompt, input_task=row.get("task", []))
        for index, row in enumerate(td, start=1)
    ]

# ----------------- 主训练循环 -----------------
async def _create_backend():
    """Create the ART backend chosen by ART_BACKEND: LocalBackend, or a SkyPilot cluster."""
    if USE_LOCAL_BACKEND:
        from art.local.backend import LocalBackend
        return LocalBackend()
    from art.skypilot.backend import SkyPilotBackend
    return await SkyPilotBackend.initialize_cluster(
        cluster_name=os.getenv("ART_SKYPILOT_CLUSTER", "art-cluster"),
        gpu=os.getenv("ART_GPU", "A100"),
    )


async def main():
    """Train the outline-filling agent with GRPO and log progress to wandb.

    The training flow is:
    1. Initialise a wandb run and create the backend.
    2. Register the trainable model.
    3. For each batch from ``iterate_dataset``:
       - gather ``rollouts_per_group`` rollouts per scenario;
       - log them;
       - call ``model.train``.

    Training stops once the global step reaches ``max_steps``, or when the
    dataset iterator is exhausted. The wandb run is always finished, with a
    non-zero exit code if training raised.
    """
    wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY if WANDB_ENTITY else None,
        name=WANDB_RUN_NAME,
        config={
            "art_project": PROJECT_NAME,
            "art_name": NAME,
            "base_model": MODEL_NAME,
            "backend": "local" if USE_LOCAL_BACKEND else "skypilot",
        },
        settings=wandb.Settings(start_method="thread"),
    )

    exit_code = 1  # reported to wandb unless training completes normally
    try:
        wandb.define_metric("*", step_metric="train/step")

        backend = await _create_backend()
        model = art.TrainableModel(name=NAME, project=PROJECT_NAME, base_model=MODEL_NAME)
        await model.register(backend)

        scenarios = build_scenarios_from_train_data(train_data)

        training_config = {
            "groups_per_step": 2,
            "num_epochs": int(os.environ.get("NUM_EPOCHS", "2")),
            "rollouts_per_group": 3,
            "learning_rate": 1e-5,
            "max_steps": 6,
        }
        wandb.config.update(training_config)

        # Resume-aware: iteration starts at the model's current step.
        it = iterate_dataset(
            scenarios,
            groups_per_step=training_config["groups_per_step"],
            num_epochs=training_config["num_epochs"],
            initial_step=await model.get_step(),
        )

        for batch in it:
            # Check before training, so only steps with index < max_steps are
            # trained. This also holds when resuming from a checkpoint.
            if batch.step >= training_config["max_steps"]:
                break

            print(f"[train] step={batch.step} epoch={batch.epoch}")

            # One TrajectoryGroup per scenario, each with rollouts_per_group rollouts.
            groups = []
            for s in batch.items:
                groups.append(
                    art.TrajectoryGroup(
                        wrap_rollout(model, rollout)(model, s)
                        for _ in range(training_config["rollouts_per_group"])
                    )
                )

            # Tolerate failures of every rollout in the batch rather than aborting.
            finished = await art.gather_trajectory_groups(
                groups, pbar_desc="gather",
                max_exceptions=training_config["rollouts_per_group"] * len(batch.items),
            )

            _log_batch_to_wandb(batch=batch, finished_groups=finished)

            # GRPO update using the rewards computed inside rollout.
            await model.train(
                finished,
                config=art.TrainConfig(learning_rate=training_config["learning_rate"]),
            )

        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

if __name__ == "__main__":
    # Script entry point: build the training coroutine and drive it on a fresh event loop.
    training_coroutine = main()
    asyncio.run(training_coroutine)
